import logging

import torch
import torch
import torch.nn as nn
from   torch.nn import CrossEntropyLoss, MSELoss

from transformers.configuration_roberta import RobertaConfig
from transformers.modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel
from transformers.modeling_roberta import RobertaModel, RobertaLMHead


logger = logging.getLogger(__name__)

# Download URLs for the standard RoBERTa checkpoints, keyed by the shortcut
# name passed to ``from_pretrained`` (e.g. "roberta-base"). Every entry follows
# the same "<bucket>/<name>-pytorch_model.bin" pattern, so the mapping is
# generated from the ordered list of shortcut names.
# NOTE(review): these are legacy s3.amazonaws.com URLs and may no longer
# resolve — verify before relying on download-by-shortcut.
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    shortcut: f"https://s3.amazonaws.com/models.huggingface.co/bert/{shortcut}-pytorch_model.bin"
    for shortcut in (
        "roberta-base",
        "roberta-large",
        "roberta-large-mnli",
        "distilroberta-base",
        "roberta-base-openai-detector",
        "roberta-large-openai-detector",
    )
}


class CMLModel(BertPreTrainedModel):
    """RoBERTa masked-LM wrapper that exposes per-example MLM losses for scoring.

    The network is a ``RobertaModel`` encoder followed by a ``RobertaLMHead``.
    ``forward`` returns token prediction scores and, when labels are supplied,
    either the batch-mean MLM loss or one averaged loss per example.
    ``score`` uses the per-example losses to compare masked variants of the
    context, question and answer-option inputs.
    """

    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config, args=None):
        """Build the encoder and LM head from ``config``.

        Args:
            config: a ``RobertaConfig``.
            args: unused; accepted so existing call sites keep working.
        """
        super().__init__(config)

        self.roberta = RobertaModel(config)
        self.lm_head = RobertaLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the LM head's decoder layer (the transformers output-embedding hook)."""
        return self.lm_head.decoder

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None,
                masked_lm_labels=None, reduce=True):
        """Run the encoder and LM head, optionally computing the MLM loss.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds: forwarded unchanged to ``RobertaModel``.
            masked_lm_labels: target token ids, one per input position.
                Positions equal to -100 (``CrossEntropyLoss``'s default
                ``ignore_index``) contribute no loss.
            reduce: if True, return a single loss averaged over all
                non-ignored tokens in the batch. If False, return a 1-D tensor
                with one loss per example.

        Returns:
            ``(prediction_scores,)``, or ``(loss, prediction_scores)`` when
            ``masked_lm_labels`` is given. Any further encoder outputs are
            discarded.
        """
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)
        # Only the prediction scores are returned; extra encoder outputs are dropped.
        outputs = (prediction_scores,)

        if masked_lm_labels is not None:
            # `reduce=` on loss modules is deprecated; `reduction` expresses the
            # same two modes ("mean" == reduce=True, "none" == reduce=False).
            loss_fct = CrossEntropyLoss(reduction="mean" if reduce else "none")
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                masked_lm_labels.view(-1),
            )
            if not reduce:
                # Take the batch size from the scores rather than input_ids so
                # this also works when only inputs_embeds was supplied.
                batch_size = prediction_scores.size(0)
                # Per-token losses -> (batch, seq) -> one mean per example.
                # NOTE: the mean divides by the full sequence length, so ignored
                # (-100) positions count as zero-loss tokens.
                loss = loss.view(batch_size, -1).mean(dim=-1)
            outputs = (loss,) + outputs
        return outputs

    # context_masked, question_masked, answer_options_masked
    def score(self, ctx_inputs, question_inputs, ans_inputs, labels, label_mask=None):
        """Compute per-option MLM losses and their products for answer scoring.

        The three input tensors are masked variants (context-masked,
        question-masked, answer-option-masked). They share the same trailing
        sequence length and are all scored against the same ``labels``.
        A 3-D ``ans_inputs`` is read as ``(batch, options, seq)``. A 2-D one is
        read as ``(options, seq)``, i.e. a single batch.

        Args:
            ctx_inputs, question_inputs, ans_inputs: token id tensors.
            labels: MLM target ids, reused for all three inputs.
            label_mask: currently unused; accepted for API compatibility.

        Returns:
            Seven tensors of shape ``(-1, options)``, in this order:
            ``d1 = ctx*ques*ans``, ``d2 = ctx*ans``, ``d3 = ans``,
            ``d4 = ques*ans``, ``d5 = ctx``, ``d6 = ques``, ``d7 = ctx*ques``.
            Each factor is the per-example mean MLM loss of that variant.
        """
        with torch.no_grad():
            option_nums = ans_inputs.shape[1] if ans_inputs.dim() > 2 else ans_inputs.shape[0]

            flat_ctx_inputs = ctx_inputs.view(-1, ctx_inputs.size(-1))
            flat_q_inputs = question_inputs.view(-1, question_inputs.size(-1))
            flat_ans_inputs = ans_inputs.view(-1, ans_inputs.size(-1))
            flat_labels = labels.view(-1, labels.size(-1))

            # Score all three masked variants in one batched forward pass.
            input_ids = torch.cat([flat_ctx_inputs, flat_q_inputs, flat_ans_inputs])
            all_labels = torch.cat([flat_labels] * 3)
            loss = self.forward(input_ids, masked_lm_labels=all_labels, reduce=False)[0]

            # Split back into the three variants (same row count each) and
            # regroup each into rows of `option_nums` options.
            ctx_loss, ques_loss, ans_loss = torch.chunk(loss, 3)
            ctx_loss = ctx_loss.view(-1, option_nums)
            ques_loss = ques_loss.view(-1, option_nums)
            ans_loss = ans_loss.view(-1, option_nums)

            d1 = ctx_loss * ques_loss * ans_loss
            d2 = ctx_loss * ans_loss
            d3 = ans_loss
            d4 = ques_loss * ans_loss
            d5 = ctx_loss
            d6 = ques_loss
            d7 = ctx_loss * ques_loss
        return d1, d2, d3, d4, d5, d6, d7
        
        

if __name__ == '__main__':
    # Smoke test: run forward() and score() on random token ids.
    # Use CUDA when present so the script also runs on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CMLModel.from_pretrained("roberta-base").to(device)
    inp_ids = torch.randint(80, [3, 10], device=device)
    labels = torch.randint(80, [3, 10], device=device)

    outputs = model(input_ids=inp_ids, masked_lm_labels=labels)
    print(f"Forward Output:{outputs}")

    # (batch=3, options=4, seq=10) inputs for the three masked variants.
    ctx_all_ids = torch.randint(100, [3, 4, 10], device=device)
    qs_all_idx = torch.randint(100, [3, 4, 10], device=device)
    ans_all_ids = torch.randint(100, [3, 4, 10], device=device)
    labels = torch.randint(80, [3, 4, 10], device=device)

    # score() takes a label_mask argument it does not use. Pass it explicitly
    # so the call satisfies the signature.
    scores = model.score(ctx_all_ids, qs_all_idx, ans_all_ids, labels=labels, label_mask=None)
    print(f"Score Output:{scores}")